# Copyright 2024 PKU-Alignment Team and Lagent Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from copy import copy
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union
from warnings import warn

import torch


class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    The parser turns a dialog (a list of plain strings and/or message dicts
    carrying ``role`` / ``content`` keys) into a single prompt string. When a
    meta template is supplied, each message is wrapped with the markers
    configured for its role; otherwise the non-empty pieces are joined with
    newlines.

    Args:
        meta_template (list of dict, optional): The meta template for the
            model. Every entry must be a dict with a unique ``role`` key; the
            optional keys read by this class are ``begin``, ``end_of_message``
            and ``end_of_turn``.
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        # Always defined, so the attribute exists even without a meta template.
        self.roles: Dict[str, dict] = dict()  # maps role name to config
        if meta_template:
            assert isinstance(meta_template, list)
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, 'role in meta prompt must be unique!'
                # Store a shallow copy so later edits of the caller's dict do
                # not leak into the parser's role table.
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            dialog (str or List[str or dict]): A prompt template (potentially
                before being wrapped by meta template). A plain string is
                returned unchanged.

        Returns:
            str: The final string.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog

        if self.meta_template:
            # Only the final entry is flagged as `last`; for a non-assistant
            # message that makes `_prompt2str` append the assistant's `begin`.
            last_index = len(dialog) - 1
            return ''.join(
                item if isinstance(item, str) else self._prompt2str(item, index == last_index)
                for index, item in enumerate(dialog)
            )

        # in case the model does not have any meta template
        prompt = ''
        last_sep = ''
        for item in dialog:
            if isinstance(item, str):
                if item:
                    prompt += last_sep + item
            elif item.get('content', ''):
                # Append the same field that was just checked. Reading a
                # different key here would drop the message text and leave
                # only the separators in the output.
                prompt += last_sep + item['content']
            # The separator is armed after the first entry even if that entry
            # was empty and therefore skipped.
            last_sep = '\n'
        return prompt

    def _prompt2str(self, prompt: Union[str, Dict], last: bool = False) -> str:
        """Render a single dialog entry using the role's meta template.

        Args:
            prompt (str or dict): A raw string (returned unchanged) or a
                message dict with a ``role`` key and optional ``content`` /
                ``name`` keys.
            last (bool): Whether this entry is the final one of the dialog.

        Returns:
            str: ``begin + content + end_of_turn`` for the role, followed by
            the assistant role's ``begin`` when ``last`` is true and the role
            is not ``assistant``.

        Raises:
            KeyError: If the message's role is not part of the meta template,
                or if the assistant's ``begin`` must be appended but no
                ``assistant`` role is configured.
        """
        if isinstance(prompt, str):
            return prompt

        role_cfg = self.roles.get(prompt['role'])
        if role_cfg is None:
            raise KeyError(f"role {prompt['role']!r} is not defined in the meta template")

        content = prompt.get('content', '')
        res = role_cfg.get('begin', '')

        # A tool call is used in this turn
        # NOTE(review): in this case the content is emitted twice (once before
        # `end_of_message`, once before `end_of_turn`) -- confirm this is intended.
        if prompt.get('role') == 'assistant' and prompt.get('name', None):
            res += content + role_cfg.get('end_of_message', '')

        res += content + role_cfg.get('end_of_turn', '')

        # Append the assistant's opening marker after a trailing non-assistant
        # message.
        if last and role_cfg['role'] != 'assistant':
            res += self.roles['assistant'].get('begin', '')
        return res


class BaseModel:
    """Base class for model wrapper.

    This base class stores the configuration, builds the template parser and
    implements :meth:`chat` on top of :meth:`generate`. The generation,
    streaming and tokenization hooks raise ``NotImplementedError`` here and
    are meant to be provided by subclasses.

    Args:
        path (str): The path to the model.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False. The flag is only stored by this base class.
        template_parser: Callable (by default the ``LMTemplateParser`` class)
            that is invoked with ``meta_template`` to build the object stored
            as ``self.template_parser``.
        meta_template (list of dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        max_new_tokens (int): Maximum length of output expected to be generated by the model. Defaults
            to 512.
        top_p (float): Default ``top_p`` generation parameter. Defaults to 0.8.
        top_k (float): Default ``top_k`` generation parameter. Defaults to 40.
        temperature (float): Default ``temperature`` generation parameter.
            Defaults to 0.8.
        repetition_penalty (float): Default ``repetition_penalty`` generation
            parameter. Defaults to 1.0.
        stop_words (list of str or str, optional): Default stop words. A single
            string is wrapped into a one-element list.
    """

    is_api: bool = False

    def __init__(
        self,
        path: str,
        tokenizer_only: bool = False,
        template_parser: 'LMTemplateParser' = LMTemplateParser,
        meta_template: Optional[List[Dict]] = None,
        *,
        max_new_tokens: int = 512,
        top_p: float = 0.8,
        top_k: float = 40,
        temperature: float = 0.8,
        repetition_penalty: float = 1.0,
        stop_words: Optional[Union[List[str], str]] = None,
    ):
        self.path = path
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = template_parser(meta_template)
        self.eos_token_id = None
        # For a list-of-dicts template this membership test compares the
        # string against each dict and is False, so `eos_token_id` stays None.
        # It only takes effect for a mapping-like `meta_template` that has an
        # 'eos_token_id' key.
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

        # Normalize a single stop word into a list.
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        # Defaults that `update_gen_params` merges per-call overrides into.
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
        )

    def generate(self, inputs: Union[str, List[str]], **gen_params) -> Union[str, List[str]]:
        """Generate results given a str (or list of) inputs.

        Args:
            inputs (Union[str, List[str]]):
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        eg.
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Args:
            inputs (str):
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def chat(self, inputs: Union[List[dict], List[List[dict]]], **gen_params):
        """Generate completion from a list of templates.

        Each dialog is rendered to a prompt string with
        ``self.template_parser`` and then passed to :meth:`generate`.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): A single dialog, or
                a batch of dialogs. The input is treated as a batch when its
                first element is a list.
            gen_params (dict): The input params for generation.

        Returns:
            Whatever :meth:`generate` returns for the rendered prompt (single
            dialog) or list of prompts (batch).
        """
        if isinstance(inputs[0], list):
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return self.generate(_inputs, **gen_params)

    def generate_from_template(self, inputs: Union[List[dict], List[List[dict]]], **gen_params):
        """Deprecated alias of :meth:`chat`.

        Emits a ``DeprecationWarning`` attributed to the caller
        (``stacklevel=2``) and forwards all arguments to :meth:`chat`.
        """
        warn(
            # The trailing space keeps the two adjacent literals from fusing
            # into "monthsand".
            'This function will be deprecated after three months '
            'and will be replaced. Please use `.chat()`',
            DeprecationWarning,
            2,
        )
        return self.chat(inputs, **gen_params)

    def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Args:
            inputs (List[dict]):
            gen_params (dict): The input params for generation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def tokenize(self, prompts: Union[str, List[str], List[dict], List[List[dict]]]):
        """Tokenize the input prompts.

        Args:
            prompts(str | List[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def update_gen_params(self, **kwargs):
        """Return the default generation params overridden by ``kwargs``.

        ``self.gen_params`` itself is left untouched; the result is a shallow
        copy, so mutable values such as the ``stop_words`` list are shared
        with the defaults.
        """
        gen_params = copy(self.gen_params)
        gen_params.update(kwargs)
        return gen_params


class BaseDiffusionModelPipeline:
    """Base class for diffusion model wrapper.

    The constructor records its arguments and then calls :meth:`_load_pipe`,
    so a subclass must override that hook before it can be instantiated.

    Args:
        model_path (str): The path to the model.
        output_path (str): Stored as ``self.output_path``; presumably where
            generated outputs are written (confirm against subclasses).
        dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
        gen_params (dict, optional): Default generation parameters. A falsy
            value is replaced by an empty dict.
    """

    def __init__(
        self,
        model_path: str,
        output_path: str,
        dtype: torch.dtype = torch.bfloat16,
        gen_params: Optional[Dict] = None,
    ) -> None:
        self.model_path, self.output_path = model_path, output_path
        self.dtype = dtype
        self.gen_params = gen_params if gen_params else dict()

        # Every attribute above is assigned first so the hook can read it.
        self.pipe = self._load_pipe()
        # self.pipe.set_progress_bar_config(disable=None)

    def _load_pipe(self):
        """Build and return the pipeline object; must be overridden."""
        raise NotImplementedError

    def generate(self, prompts: Union[str, List[str]]) -> str:
        """Generate results given a str (or list of) inputs.

        Args:
            prompts (Union[str, List[str]]):

        Returns:
            Union[str, List[str]]: A (list of) output path of Modality.

        """
        raise NotImplementedError

    def update_gen_params(self, **kwargs):
        """Return a shallow copy of the default params updated with ``kwargs``."""
        merged = copy(self.gen_params)
        for key, value in kwargs.items():
            merged[key] = value
        return merged

    def get_timestamp(self):
        """Return the current local time formatted as ``YYYYmmdd-HHMMSS``."""
        now = datetime.now()
        return now.strftime('%Y%m%d-%H%M%S')
